{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# [SubsAI](https://github.com/abdeladim-s/subsai): Voice Activity Detection (VAD) example\n",
    "\n",
    "If you have any issues, questions or suggestions, post a new issue [here](https://github.com/abdeladim-s/subsai/issues) or create a new discussion [here](https://github.com/abdeladim-s/subsai/discussions).\n",
    "\n",
    "## This notebook shows how to use subsai with a VAD. This example uses [silero-vad](https://github.com/snakers4/silero-vad).\n",
    "\n",
    "* ### The idea is to cut long media files into chunks using the VAD and process them one by one.\n",
    "* ### Progress is saved after each processed chunk, so you can interrupt the process whenever you want and resume it later."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# System dependency. -y answers the confirmation prompt, which a notebook cell cannot do interactively.\n",
    "!apt install -y ffmpeg\n",
    "# Python dependencies. %pip installs into the environment of the running kernel.\n",
    "%pip install jedi\n",
    "%pip install git+https://github.com/abdeladim-s/subsai.git\n",
    "%pip install -q torchaudio"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Importing the dtw module. When using in academic works please cite:\n",
      "  T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.\n",
      "  J. Stat. Soft., doi:10.18637/jss.v031.i07.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import torch\n",
    "from pathlib import Path\n",
    "import pysubs2\n",
    "from pysubs2 import SSAFile\n",
    "import pickle\n",
    "from subsai import SubsAI"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Global vars"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling rate passed to the audio reader, the VAD and the audio writer\n",
    "SAMPLING_RATE = 16000\n",
    "\n",
    "# Input media file and the files derived from its path\n",
    "media_file = '../assets/video/test0.webm'\n",
    "subs_file = f'{media_file}.srt'\n",
    "chunks_object_path = f'{media_file}-chunks.pkl'\n",
    "\n",
    "# Transcription model name and its configuration\n",
    "transcription_model = 'ggerganov/whisper.cpp'\n",
    "transcription_configs = dict(model_type='base')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load silero"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the silero-vad model and its helper functions through torch hub.\n",
    "# force_reload=True asks torch hub for a fresh download instead of reusing its cache.\n",
    "model, utils = torch.hub.load(\n",
    "    repo_or_dir='snakers4/silero-vad',\n",
    "    model='silero_vad',\n",
    "    force_reload=True,\n",
    "    onnx=False,\n",
    ")\n",
    "\n",
    "# Unpack the helpers; the order must match the tuple returned by the hub entry point.\n",
    "get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks = utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Transcription model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the transcription model once; the processing loop reuses it for every chunk.\n",
    "subs_ai = SubsAI()\n",
    "tr_model = subs_ai.create_model(\n",
    "    transcription_model,\n",
    "    transcription_configs,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Magic: chunked, resumable transcription"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading stored chunks object ...\n",
      "Loading subtitles file ../assets/video/test0.webm.srt...\n",
      "Done :)\n"
     ]
    }
   ],
   "source": [
    "wav = read_audio(media_file, sampling_rate=SAMPLING_RATE)\n",
    "\n",
    "# Resume from a checkpoint when one exists; otherwise run the VAD from scratch.\n",
    "# The checkpoint is a pickled tuple: (speech_timestamps, index of the last handled chunk).\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load checkpoints written by this notebook.\n",
    "resuming = Path(chunks_object_path).exists()\n",
    "if resuming:\n",
    "    print(\"Loading stored chunks object ...\")\n",
    "    with open(chunks_object_path, 'rb') as checkpoint_file:\n",
    "        speech_timestamps, chunk_idx_start = pickle.load(checkpoint_file)\n",
    "else:\n",
    "    print(\"No chunks object found, running silero ...\")\n",
    "    speech_timestamps = get_speech_timestamps(wav, model, return_seconds=False, sampling_rate=SAMPLING_RATE)\n",
    "    chunk_idx_start = -1\n",
    "\n",
    "# Reuse an existing subtitles file only when resuming. Without a checkpoint every chunk is\n",
    "# transcribed again, so appending to old subtitles would duplicate every line.\n",
    "if resuming and Path(subs_file).exists():\n",
    "    print(f\"Loading subtitles file {subs_file}...\")\n",
    "    subs = pysubs2.load(subs_file)\n",
    "else:\n",
    "    subs = SSAFile()\n",
    "\n",
    "# Index of the last chunk that was fully handled; this is what the checkpoint stores.\n",
    "last_done_idx = chunk_idx_start\n",
    "for chunk_idx in range(chunk_idx_start + 1, len(speech_timestamps)):\n",
    "    ts = speech_timestamps[chunk_idx]\n",
    "    print(f\"Processing chunk {chunk_idx}...\")\n",
    "    start = ts['start']\n",
    "    end = ts['end']\n",
    "    # temporary audio file handed to the transcription model\n",
    "    chunk_path = media_file + f\"-chunk-{start}-{end}.wav\"\n",
    "    chunk_handled = False\n",
    "    try:\n",
    "        save_audio(chunk_path, collect_chunks([ts], wav), sampling_rate=SAMPLING_RATE)\n",
    "        # transcribe\n",
    "        chunk_subs = subs_ai.transcribe(chunk_path, tr_model)\n",
    "        # Shift the chunk-relative times by the chunk's position in the full audio.\n",
    "        # start is a sample index (return_seconds=False), converted here to whole milliseconds.\n",
    "        offset_ms = round(start * 1000 / SAMPLING_RATE)\n",
    "        for sub in chunk_subs:\n",
    "            sub.start += offset_ms\n",
    "            sub.end += offset_ms\n",
    "        # update global subs\n",
    "        if not subs:\n",
    "            subs = chunk_subs\n",
    "        else:\n",
    "            subs.extend(chunk_subs)\n",
    "        chunk_handled = True\n",
    "    except Exception as e:\n",
    "        # best effort: report the failure and move on to the next chunk\n",
    "        print(f\"Skipping chunk {chunk_idx} after error: {e!r}\")\n",
    "        chunk_handled = True\n",
    "    finally:\n",
    "        # This also runs when the cell is interrupted. An interrupted chunk is not marked\n",
    "        # as handled, so it is transcribed again on the next run instead of being lost.\n",
    "        if chunk_handled:\n",
    "            last_done_idx = chunk_idx\n",
    "        # clean (the file may be missing if writing the chunk failed)\n",
    "        if os.path.exists(chunk_path):\n",
    "            os.remove(chunk_path)\n",
    "        # save subs file\n",
    "        print(f\"Saving subtitles file {subs_file} ...\")\n",
    "        subs.save(subs_file)\n",
    "        # save the checkpoint as a (speech_timestamps, last_done_idx) tuple\n",
    "        print(\"Saving chunks object ...\")\n",
    "        with open(chunks_object_path, 'wb') as checkpoint_file:\n",
    "            pickle.dump((speech_timestamps, last_done_idx), checkpoint_file)\n",
    "\n",
    "# All chunks handled: the checkpoint is no longer needed. It may not exist at all\n",
    "# (e.g. when no speech was detected), so only remove it when present.\n",
    "if os.path.exists(chunks_object_path):\n",
    "    os.remove(chunks_object_path)\n",
    "print(\"Done :)\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
